package selector

import (
	"context"
	"errors"
	"sync"
	"sync/atomic"
)

// WeightedRoundRobin holds the state of a weighted round-robin node selector.
//
// mux is locked by Add and Next for the whole of each call. nodes holds the
// []Node slice stored by Add; the remaining fields are the bookkeeping that
// Next advances on every selection.
type WeightedRoundRobin struct {
	mux       sync.Mutex
	nodes     atomic.Value
	nodeCount int   // number of nodes, set by Add
	maxWeight int64 // largest node weight after Add; Next lowers it and uses it as the current selection threshold
	gcdWeight int64 // step Next subtracts from maxWeight on each pass; meant to be the GCD of the node weights
	lastIndex int   // index of the most recently visited node; -1 before the first selection
}

// WeightedRoundRobinBuilder wraps a pointer to a WeightedRoundRobin.
//
// NOTE(review): nothing in the visible part of this file uses this type;
// NewWeightedRoundRobinBuilder returns a *WeightedRoundRobin instead.
// Confirm whether it is still needed.
type WeightedRoundRobinBuilder struct {
	w *WeightedRoundRobin
}

// NewWeightedRoundRobinBuilder returns a Builder backed by a zero-value
// WeightedRoundRobin.
func NewWeightedRoundRobinBuilder() Builder {
	return new(WeightedRoundRobin)
}

// Build returns a new selector obtained from NewWeightedRoundRobin. The
// receiver's own state is not read or modified.
func (w *WeightedRoundRobin) Build() Selector {
	selector := NewWeightedRoundRobin()
	return selector
}

// NewWeightedRoundRobin creates an empty weighted round-robin selector.
//
// The node count and maximum weight start at their zero values, the weight
// step starts at 1 and the last-selected index at -1.
func NewWeightedRoundRobin() Selector {
	return &WeightedRoundRobin{
		gcdWeight: 1,
		lastIndex: -1,
	}
}

// Add replaces the selector's node set with nodes and recomputes the weight
// bookkeeping (node count, maximum weight, weight step, cursor) used by Next.
//
// It returns an error when called without any node; otherwise it returns
// whatever init returns.
func (w *WeightedRoundRobin) Add(nodes ...Node) error {
	if len(nodes) == 0 {
		return errors.New("params len 1 at least")
	}

	// Store a private copy: a call such as Add(s...) hands over the caller's
	// backing array, and a later write to s must not change the node set.
	snapshot := make([]Node, len(nodes))
	copy(snapshot, nodes)

	w.mux.Lock()
	defer w.mux.Unlock()

	// Reset all derived state before init folds the new weights in.
	w.nodes.Store(snapshot)
	w.nodeCount = len(snapshot)
	w.maxWeight = 0
	// Seed the accumulator with 0, the identity for gcd (gcd(0, x) == x).
	// Seeding it with 1 would pin the result to 1 whatever the weights are.
	w.gcdWeight = 0
	w.lastIndex = -1

	if err := w.init(); err != nil {
		return err
	}
	// Next subtracts gcdWeight from its threshold on every pass; keep the
	// step non-negative so the threshold can never grow without bound.
	if w.gcdWeight < 0 {
		w.gcdWeight = -w.gcdWeight
	}
	return nil
}

// init folds the weights of the stored nodes into w.maxWeight and
// w.gcdWeight, starting from whatever values those fields currently hold.
//
// It returns an error when no node slice has been stored yet.
func (w *WeightedRoundRobin) init() error {
	// Load through the field itself: an atomic.Value must not be copied
	// after first use, so it is never assigned to a local.
	nodes, ok := w.nodes.Load().([]Node)
	if !ok {
		return errors.New("NODE_NOT_FOUND")
	}
	for _, node := range nodes {
		// Read the weight once so the maximum and the GCD are derived from
		// the same value.
		weight := node.InitialWeight()
		if weight > w.maxWeight {
			w.maxWeight = weight
		}
		w.gcdWeight = gcd(w.gcdWeight, weight)
	}
	return nil
}

// gcd returns the greatest common divisor of a and b using the Euclidean
// algorithm.
//
// The result is never negative, whatever the signs of the inputs. gcd(x, 0)
// and gcd(0, x) both yield |x|, and gcd(0, 0) is 0.
func gcd(a, b int64) int64 {
	for b != 0 {
		a, b = b, a%b
	}
	// Go's % takes the sign of the dividend, so the loop can end on a
	// negative value; report the magnitude.
	if a < 0 {
		return -a
	}
	return a
}

// Next selects the next node in weighted round-robin order.
//
// The cursor walks the node slice cyclically. Each time it wraps to index 0
// the threshold (maxWeight) is lowered by gcdWeight; once it drops to zero or
// below it is reset to the largest node weight. The first node whose weight
// is at least the threshold is returned.
//
// ctx is only used to look up a peer via FromPeerContext; when one is
// present its Node field is set to the selected node. Cancellation of ctx is
// not checked.
//
// An error is returned when no nodes have been stored, when the stored slice
// is empty, or when the largest node weight is 0.
func (wrr *WeightedRoundRobin) Next(ctx context.Context) (selected Node, err error) {
	wrr.mux.Lock()
	defer wrr.mux.Unlock()

	// Load through the field itself: an atomic.Value must not be copied
	// after first use, so it is never assigned to a local.
	nodes, ok := wrr.nodes.Load().([]Node)
	if !ok {
		return nil, errors.New("NODE_NOT_FOUND")
	}
	if len(nodes) == 0 {
		return nil, errors.New("node len 0")
	}

	for {
		// Wrap on the length of the slice actually being indexed so the
		// cursor can never run past it.
		wrr.lastIndex = (wrr.lastIndex + 1) % len(nodes)
		if wrr.lastIndex == 0 {
			// A full pass is complete: lower the threshold one step.
			wrr.maxWeight -= wrr.gcdWeight
			if wrr.maxWeight <= 0 {
				highest, gerr := wrr.getMaxWeight()
				if gerr != nil {
					return nil, gerr
				}
				wrr.maxWeight = highest
				if wrr.maxWeight == 0 {
					return nil, errors.New("wrr.maxWeight == 0")
				}
			}
		}

		node := nodes[wrr.lastIndex]
		if node.InitialWeight() >= wrr.maxWeight {
			if peer, ok := FromPeerContext(ctx); ok {
				peer.Node = node
			}
			return node, nil
		}
	}
}

// getMaxWeight returns the largest InitialWeight among the stored nodes, or 0
// when no node has a positive weight.
//
// It returns an error when no node slice has been stored yet. It does not
// lock wrr.mux itself.
func (wrr *WeightedRoundRobin) getMaxWeight() (int64, error) {
	// Load through the field itself: an atomic.Value must not be copied
	// after first use, so it is never assigned to a local.
	nodes, ok := wrr.nodes.Load().([]Node)
	if !ok {
		return 0, errors.New("NODE_NOT_FOUND")
	}

	var maxWeight int64
	for _, node := range nodes {
		// Read the weight once so the comparison and the assignment use the
		// same value.
		if weight := node.InitialWeight(); weight > maxWeight {
			maxWeight = weight
		}
	}
	return maxWeight, nil
}
